{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "Foggy_CycleGAN.ipynb",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true,
   "include_colab_link": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/ghaiszaher/Foggy-CycleGAN/blob/master/Foggy_CycleGAN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_xnMOsbqHz61"
   },
   "source": [
    "# Foggy-CycleGAN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "e1_Y75QXJS6h"
   },
   "source": [
    "## Set up the input pipeline"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "lhSsUx9Nyb3t",
    "colab": {}
   },
   "source": [
    "import sys\n",
    "\n",
    "# Colab is detected by the google.colab module already being present in sys.modules.\n",
    "colab = 'google.colab' in sys.modules\n",
    "\n",
    "if colab:\n",
    "    # Best effort: request TensorFlow 2.x through the Colab line magic.\n",
    "    # Any exception raised by the magic is deliberately ignored.\n",
    "    # noinspection PyBroadException\n",
    "    try:\n",
    "        %tensorflow_version 2.x\n",
    "    except Exception:\n",
    "        pass\n",
    "\n",
    "import tensorflow as tf"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "gct0xca4KwM6",
    "pycharm": {
     "name": "#%%\n"
    },
    "colab": {}
   },
   "source": [
    "# noinspection PyUnresolvedReferences\n",
    "print(tf.__version__)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "YfIk2es3hJEd",
    "colab": {}
   },
   "source": [
    "import tensorflow_datasets as tfds\n",
    "\n",
    "import os\n",
    "from IPython.display import clear_output\n",
    "\n",
    "tfds.disable_progress_bar()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "0Y14JF3xcXWB",
    "colab": {}
   },
   "source": [
    "if colab:\n",
    "    # Export the checkout location so the `!` shell lines below can read it as $PROJECT_DIR.\n",
    "    os.environ['PROJECT_DIR'] = project_dir = '/content/Foggy-CycleGAN'\n",
    "    replace = True\n",
    "    if os.path.isdir(project_dir):\n",
    "        choice = input(\"Project already exists in folder \"+\n",
    "              \"{}\\nDelete the files and pull again? Enter Y/(N):\\n\"\n",
    "              .format(project_dir))\n",
    "        # Strip surrounding whitespace so an answer such as 'y ' still counts as yes.\n",
    "        if choice.strip().lower()=='y':\n",
    "            !rm -r $PROJECT_DIR\n",
    "            print(\"Deleted folder {}\".format(project_dir))\n",
    "        else:\n",
    "            replace = False\n",
    "            print(\"Nothing was changed.\")\n",
    "    if replace:\n",
    "        # Clone into $PROJECT_DIR explicitly so the checkout always lands where os.chdir below expects it.\n",
    "        !git clone https://github.com/ghaiszaher/Foggy-CycleGAN.git $PROJECT_DIR\n",
    "        print(\"Project cloned to \" + project_dir)\n",
    "    # Later cells import from `lib`; presumably that package lives in this directory.\n",
    "    os.chdir(project_dir)\n",
    "    print(\"Done.\")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "iYn4MdZnKCey"
   },
   "source": [
    "## Prepare Datasets"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "2CbTEt448b4R",
    "colab": {}
   },
   "source": [
    "BATCH_SIZE = 5 if colab else 1\n",
    "IMG_WIDTH = 256\n",
    "IMG_HEIGHT = 256"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Qu5_c2ecxZBF",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "project_label = \"\" #@param {type:\"string\"}"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Defaults for non-Colab runs, so every name below is defined on both code paths\n",
    "# (this also suppresses possibly-undefined warnings).\n",
    "mount_path = None\n",
    "drive_project_path = None\n",
    "drive_datasets_path = None\n",
    "if colab:\n",
    "    # noinspection PyUnresolvedReferences\n",
    "    from google.colab import drive\n",
    "    mount_path = '/content/drive'\n",
    "    drive.mount(mount_path, force_remount=True)\n",
    "    drive_project_path = os.path.join(mount_path,\"My Drive/Colab Notebooks/Foggy-CycleGAN/\",project_label)\n",
    "    drive_datasets_path = os.path.join(mount_path,\"My Drive/Colab Notebooks/Datasets/\")\n",
    "    # Exported as environment variables; presumably read by shell scripts run from later cells (confirm).\n",
    "    os.environ['DRIVE_PROJECT'] = drive_project_path\n",
    "    os.environ['DRIVE_DATASETS'] = drive_datasets_path"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "95a1NFSv6Qfa",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "#Unzip dataset from Google Drive to /content/dataset/ folder\n",
    "if colab:\n",
    "    !sh $PROJECT_DIR/copy_dataset.sh"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "id": "-7Za2DrexGpg",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "test_split = 0.2 #@param {type:\"slider\", min:0.05, max:0.95, step:0.05}"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from lib.dataset import DatasetInitializer\n",
    "\n",
    "# Reuse the shared image-size constants instead of repeating the literal 256.\n",
    "# NOTE(review): positional order (height, width) is assumed; both are 256 today, so behaviour is\n",
    "# unchanged, but confirm against DatasetInitializer before using a non-square size.\n",
    "datasetInit = DatasetInitializer(IMG_HEIGHT, IMG_WIDTH)\n",
    "datasetInit.dataset_path = '/content/dataset/' if colab else  './dataset/'\n",
    "# prepare_dataset returns three (clear, fog) pairs: train, test and sample.\n",
    "(train_clear, train_fog), (test_clear, test_fog), (sample_clear, sample_fog) = datasetInit.prepare_dataset(\n",
    "    BATCH_SIZE,\n",
    "    test_split=test_split,\n",
    "    random_seed=7)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hvX8sKsfMaio",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Build Generator"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "8ju9Wyw87MRW",
    "colab": {}
   },
   "source": [
    "from lib.models import ModelsBuilder\n",
    "OUTPUT_CHANNELS = 3\n",
    "models_builder = ModelsBuilder()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "1PlWxBqiHu1M",
    "colab": {}
   },
   "source": [
    "use_transmission_map = False #@param{type: \"boolean\"}\n",
    "use_gauss_filter = False #@param{type: \"boolean\"}\n",
    "use_resize_conv = False #@param{type: \"boolean\"}\n",
    "\n",
    "# Clear -> fog generator: built with the three options selected above.\n",
    "clear2fog_options = dict(\n",
    "    use_transmission_map=use_transmission_map,\n",
    "    use_gauss_filter=use_gauss_filter,\n",
    "    use_resize_conv=use_resize_conv,\n",
    ")\n",
    "generator_clear2fog = models_builder.build_generator(**clear2fog_options)\n",
    "\n",
    "# Fog -> clear generator: only use_transmission_map is passed (False); every other\n",
    "# option keeps whatever default build_generator defines.\n",
    "generator_fog2clear = models_builder.build_generator(use_transmission_map=False)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "ZEiYYoqtKwOA",
    "pycharm": {
     "name": "#%%\n"
    },
    "colab": {}
   },
   "source": [
    "tf.keras.utils.plot_model(generator_clear2fog, show_shapes=True, dpi=64, to_file='generator_clear2fog.png');"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "YdQXzKvjKwON",
    "pycharm": {
     "name": "#%%\n"
    },
    "colab": {}
   },
   "source": [
    "tf.keras.utils.plot_model(generator_fog2clear, show_shapes=True, dpi=64, to_file='generator_fog2clear.png');"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FsCfIhjeIDGM",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Build Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "eSZb1ZsOIbhj",
    "colab": {}
   },
   "source": [
    "use_intensity_for_fog_discriminator = False #@param{type: \"boolean\"}\n",
    "discriminator_fog = models_builder.build_discriminator(use_intensity=use_intensity_for_fog_discriminator)\n",
    "discriminator_clear = models_builder.build_discriminator(use_intensity=False)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "lAp1Ob5ZMAKe",
    "colab": {}
   },
   "source": [
    "tf.keras.utils.plot_model(discriminator_fog, show_shapes=True, dpi=64, to_file=\"discriminator_fog.png\");\n",
    "tf.keras.utils.plot_model(discriminator_clear, show_shapes=True, dpi=64, to_file=\"discriminator_clear.png\");"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aKUZnDiqQrAh",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "WJnftd5sQsv6",
    "colab": {}
   },
   "source": [
    "# Weights live under the Drive project folder on Colab and in ./weights/ otherwise.\n",
    "weights_path = os.path.join(drive_project_path, 'weights/') if colab else \"./weights/\""
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "E5qke6pCSeoK",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "from lib.train import Trainer\n",
    "\n",
    "# One trainer receives both generators and both discriminators.\n",
    "trainer = Trainer(\n",
    "    generator_clear2fog,\n",
    "    generator_fog2clear,\n",
    "    discriminator_fog,\n",
    "    discriminator_clear,\n",
    ")\n",
    "\n",
    "# Checkpoints use weights_path; load_optimizers=False presumably skips restoring optimizer state (confirm in Trainer).\n",
    "trainer.configure_checkpoint(weights_path=weights_path, load_optimizers=False)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "wDaGZ3WpZUyw",
    "pycharm": {
     "name": "#%%\n"
    },
    "colab": {}
   },
   "source": [
    "from lib.plot import plot_generators_predictions\n",
    "\n",
    "# Pair the first element of each sample dataset and plot both generators' predictions for it.\n",
    "first_sample_pair = tf.data.Dataset.zip((sample_clear.take(1), sample_fog.take(1)))\n",
    "for clear, fog in first_sample_pair:\n",
    "    plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "O5MhJmxyZiy9",
    "pycharm": {
     "name": "#%%\n"
    },
    "colab": {}
   },
   "source": [
    "from lib.plot import plot_discriminators_predictions\n",
    "\n",
    "# Pair the first element of each sample dataset and plot both discriminators' predictions for it.\n",
    "first_sample_pair = tf.data.Dataset.zip((sample_clear.take(1), sample_fog.take(1)))\n",
    "for clear, fog in first_sample_pair:\n",
    "    plot_discriminators_predictions(\n",
    "        discriminator_clear, clear,\n",
    "        discriminator_fog, fog,\n",
    "        use_intensity_for_fog_discriminator,\n",
    "    )"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Rw1fkAczTQYh"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "id": "UaubWnbRxGrh",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "use_tensorboard = True #@param{type:\"boolean\"}"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "id": "TndgaMHnxGrs",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "if use_tensorboard:\n",
    "    from urllib.parse import urlparse\n",
    "    import tensorboard\n",
    "    tb = tensorboard.program.TensorBoard()\n",
    "    if colab:\n",
    "        # On Colab the logs go to the Drive project folder.\n",
    "        trainer.tensorboard_base_logdir = os.path.join(drive_project_path,\"tensorboard_logs/\")\n",
    "    tb.configure(argv=[None, '--logdir', trainer.tensorboard_base_logdir])\n",
    "    url = tb.launch()\n",
    "    if colab:\n",
    "        # Embed the port this instance actually bound to (taken from the launch URL) rather than\n",
    "        # assuming 6006, which may belong to an earlier instance; fall back to 6006 if the URL has no port.\n",
    "        tensorboard.notebook.display(port=urlparse(url).port or 6006, height=1000)\n",
    "    else:\n",
    "        print(url)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "id": "5LeVzPCaxGrz",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "if colab:\n",
    "    trainer.image_log_path = os.path.join(drive_project_path,\"image_logs/\")\n",
    "    trainer.config_path  = os.path.join(drive_project_path,\"trainer_config.json\")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "id": "ie_Z7WkRxGr5",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "trainer.load_config()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "NS2GWywBbAWo",
    "colab": {}
   },
   "source": [
    "use_transmission_map_loss=True #@param{type: \"boolean\"}\n",
    "use_whitening_loss=True #@param{type: \"boolean\"}\n",
    "use_rgb_ratio_loss=True #@param{type: \"boolean\"}\n",
    "save_optimizers=False #@param{type: \"boolean\"}\n",
    "# Number of training epochs, exposed as a form field like the switches above (default unchanged: 100).\n",
    "epochs=100 #@param{type: \"integer\"}\n",
    "\n",
    "trainer.train(\n",
    "    train_clear, train_fog,\n",
    "    epochs=epochs,\n",
    "    # Clears the cell output via IPython, waiting for new output before clearing.\n",
    "    clear_output_callback=lambda: clear_output(wait=True),\n",
    "    use_tensorboard=use_tensorboard,\n",
    "    sample_test=(sample_clear, sample_fog),\n",
    "    # trainer.load_config() is already called in an earlier cell.\n",
    "    load_config_first=False,\n",
    "    use_transmission_map_loss=use_transmission_map_loss,\n",
    "    use_whitening_loss=use_whitening_loss,\n",
    "    use_rgb_ratio_loss=use_rgb_ratio_loss,\n",
    "    save_optimizers=save_optimizers,\n",
    "    use_intensity_for_fog_discriminator=use_intensity_for_fog_discriminator\n",
    ")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "id": "bmA0Q6ZxxGsG",
    "colab_type": "text"
   },
   "source": [
    "## Testing"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "colab_type": "code",
    "id": "NLnOk_xqvFNm",
    "pycharm": {
     "name": "#%%\n"
    },
    "colab": {}
   },
   "source": [
    "# TODO: store predictions\n",
    "# Plot generator predictions for the first five elements of each test dataset, paired in order.\n",
    "test_pairs = zip(test_clear.take(5), test_fog.take(5))\n",
    "for clear, fog in test_pairs:\n",
    "    plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Plot generator predictions for every element of the sample datasets, paired in order.\n",
    "sample_pairs = zip(sample_clear, sample_fog)\n",
    "for clear, fog in sample_pairs:\n",
    "    plot_generators_predictions(generator_clear2fog, clear, generator_fog2clear, fog)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from lib.plot import plot_clear2fog_intensity\n",
    "from matplotlib import pyplot as plt\n",
    "from lib.tools import create_dir\n",
    "\n",
    "# Folder that receives one saved figure per intensity value.\n",
    "intensity_path = './intensity/'\n",
    "create_dir(intensity_path)\n",
    "\n",
    "# NOTE(review): [0][0] presumably selects one clear image from the first test element; confirm the element layout.\n",
    "image_clear = next(iter(test_clear))[0][0]\n",
    "step = 0.05\n",
    "# Intensities start at 0 and advance by `step`; the upper bound 1 + step is exclusive.\n",
    "for index, intensity in enumerate(tf.range(0, 1 + step, step)):\n",
    "    fig = plot_clear2fog_intensity(generator_clear2fog, image_clear, intensity)\n",
    "    file_name = \"{:02d}_intensity_{:0.2f}.jpg\".format(index, intensity)\n",
    "    fig.savefig(os.path.join(intensity_path, file_name), bbox_inches='tight', pad_inches=0)\n",
    "    # Display the figure on Colab; otherwise close it.\n",
    "    if colab:\n",
    "        plt.show()\n",
    "    else:\n",
    "        plt.close(fig)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "if colab:\n",
    "    # `&&` instead of `;`: zip must only run if the cd succeeded, otherwise it would\n",
    "    # archive the contents of the current directory instead of ./intensity.\n",
    "    !cd ./intensity && zip /content/intensity.zip *"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Testing Custom images\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from lib.plot import plot_clear2fog_intensity\n",
    "from matplotlib import pyplot as plt\n",
    "from lib.tools import create_dir\n",
    "\n",
    "# Folder that receives one saved figure per intensity value.\n",
    "intensity_path = './intensity/'\n",
    "create_dir(intensity_path)\n",
    "\n",
    "# Path of the PNG image to test; a form field so it can be changed without editing the code.\n",
    "file_path = 'E:/Downloads/test-image.png' #@param{type: \"string\"}\n",
    "# Fail early with a readable message instead of an error raised inside the TensorFlow file reader.\n",
    "if not tf.io.gfile.exists(file_path):\n",
    "    raise FileNotFoundError(\"Custom test image not found: {}\".format(file_path))\n",
    "\n",
    "image_clear = tf.io.decode_png(tf.io.read_file(file_path), channels=3)\n",
    "# NOTE(review): the second argument (0) and the discarded second return value are not explained here;\n",
    "# presumably a placeholder label/intensity. Confirm against DatasetInitializer.preprocess_image_test.\n",
    "image_clear, _ = datasetInit.preprocess_image_test(image_clear, 0)\n",
    "step = 0.05\n",
    "# Intensities start at 0 and advance by `step`; the upper bound 1 + step is exclusive.\n",
    "for (ind, i) in enumerate(tf.range(0,1+step, step)):\n",
    "    fig = plot_clear2fog_intensity(generator_clear2fog, image_clear, i)\n",
    "    fig.savefig(os.path.join(intensity_path\n",
    "                             , \"{:02d}_intensity_{:0.2f}.jpg\".format(ind,i)), bbox_inches='tight', pad_inches=0)\n",
    "    # Display the figure on Colab; otherwise close it.\n",
    "    if colab:\n",
    "        plt.show()\n",
    "    else:\n",
    "        plt.close(fig)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "if colab:\n",
    "    # `&&` instead of `;`: zip must only run if the cd succeeded, otherwise it would\n",
    "    # archive the contents of the current directory instead of ./intensity.\n",
    "    !cd ./intensity && zip /content/intensity.zip *"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ]
}